import tensorflow as tf

def scaled_dot_product_attention(q, k, v, mask, mixed, scale, dropout_rate, training):
    """Compute scaled dot-product attention and its weights.

    q, k, v must have matching leading dimensions.
    k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.

    Args:
    q: query shape == (..., seq_len_q, depth)
    k: key shape == (..., seq_len_k, depth)
    v: value shape == (..., seq_len_v, depth_v)
    mask: Multiplicative mask broadcastable to (..., seq_len_q, seq_len_k),
          or None to skip masking. Entries equal to 1 keep the logit,
          entries equal to 0 replace it with a large negative value.
          It is multiplied with the logits, so its dtype must be
          compatible with theirs.
    mixed: Boolean. If true, the scaling factor and the masking constant
           are taken in float16, otherwise in float32.
    scale: Boolean. If true, the logits are divided by sqrt(depth).
    dropout_rate: Rate passed to tf.nn.dropout for the attention weights.
    training: If truthy, dropout is applied to the attention weights.

    Returns:
    output, attention_weights
    output has shape (..., seq_len_q, depth_v). attention_weights has shape
    (..., seq_len_q, seq_len_k) and is returned after dropout when
    training is truthy.
    """
    compute_dtype = tf.float16 if mixed else tf.float32

    attention_logits = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)

    if scale:
        # The scaling factor is the key/query depth. Reading it from v would
        # give the wrong factor whenever depth_v differs from depth.
        dk = tf.cast(tf.shape(k)[-1], compute_dtype)
        attention_logits = attention_logits / tf.math.sqrt(dk)

    if mask is not None:
        # Half of the most negative finite value of the compute dtype.
        negative_min = compute_dtype.min / 2
        # Keep logits where mask == 1; where mask == 0 the logit becomes
        # negative_min, so softmax assigns it (near) zero weight.
        attention_logits = attention_logits * mask + negative_min * (1 - mask)

    # softmax is normalized on the last axis (seq_len_k) so that the scores add up to 1.
    attention_weights = tf.nn.softmax(attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)

    if training:
        attention_weights = tf.nn.dropout(attention_weights, dropout_rate)

    output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

    return output, attention_weights

class MHA(tf.keras.layers.Layer):
    """Multi-head attention layer.

    Projects queries, keys and values with separate Dense layers, splits the
    projections into ``num_heads`` heads of size ``d_model // num_heads``,
    runs ``scaled_dot_product_attention`` on them, merges the heads back and
    applies a final Dense projection.

    Args:
    d_model: Width of every projection; must be divisible by num_heads.
    num_heads: Number of attention heads.
    initializer: Kernel initializer used by all four Dense layers.
    mixed: Forwarded to scaled_dot_product_attention.
    scale: Forwarded to scaled_dot_product_attention.
    attn_pdrop: Dropout rate for the attention weights.
    resid_pdrop: Dropout rate for the projected output.
    """

    def __init__(self, d_model, num_heads, initializer, mixed, scale, attn_pdrop, resid_pdrop):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.attn_pdrop = attn_pdrop
        self.resid_pdrop = resid_pdrop
        self.mixed = mixed
        self.scale = scale

        assert d_model % self.num_heads == 0

        self.depth = d_model // self.num_heads

        def make_projection():
            # All projections share the same width and initializer.
            return tf.keras.layers.Dense(d_model, kernel_initializer=initializer)

        self.wq = make_projection()
        self.wk = make_projection()
        self.wv = make_projection()

        self.dense = make_projection()

    def split_heads(self, x, batch_size):
        """Reshape (batch_size, seq_len, d_model) to (batch_size, num_heads, seq_len, depth)."""
        per_head = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        # Move the head axis in front of the sequence axis.
        return tf.transpose(per_head, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask, training):
        """Run multi-head attention.

        Note the argument order: values, keys, queries.

        Returns:
        output, attention_weights
        output has shape (batch_size, seq_len_q, d_model). attention_weights
        has shape (batch_size, num_heads, seq_len_q, seq_len_k).
        """
        batch_size = tf.shape(q)[0]

        # Linear projections, each (batch_size, seq_len, d_model).
        q = self.wq(q)
        k = self.wk(k)
        v = self.wv(v)

        # Per-head views, each (batch_size, num_heads, seq_len, depth).
        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)

        # per_head_output: (batch_size, num_heads, seq_len_q, depth)
        # attention_weights: (batch_size, num_heads, seq_len_q, seq_len_k)
        per_head_output, attention_weights = scaled_dot_product_attention(
            q, k, v, mask, self.mixed, self.scale, self.attn_pdrop, training)

        # Bring heads next to depth, then fold them into one d_model axis.
        heads_last = tf.transpose(per_head_output, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)
        merged = tf.reshape(heads_last, (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)

        output = self.dense(merged)  # (batch_size, seq_len_q, d_model)
        if training:
            output = tf.nn.dropout(output, self.resid_pdrop)

        return output, attention_weights

def point_wise_feed_forward_network(d_model, dff):
    """Build the two-layer feed-forward network.

    Args:
    d_model: Number of units of the second (output) Dense layer.
    dff: Number of units of the first Dense layer, which uses GELU.

    Returns:
    A tf.keras.Sequential of Dense(dff, gelu) followed by Dense(d_model).
    """
    hidden = tf.keras.layers.Dense(dff, activation='gelu')  # (batch_size, seq_len, dff)
    projection = tf.keras.layers.Dense(d_model)  # (batch_size, seq_len, d_model)
    return tf.keras.Sequential([hidden, projection])

class DecoderLayer(tf.keras.layers.Layer):
    """One decoder block: self-attention and a feed-forward network.

    Each sub-block adds its input back to its output and then applies layer
    normalization. Dropout on the attention output is handled inside MHA;
    dropout on the feed-forward output is applied here.

    Args:
    d_model, num_heads, initializer, mixed, scale, attn_pdrop: Forwarded to MHA.
    dff: Hidden width of the feed-forward network.
    resid_pdrop: Forwarded to MHA and used as the feed-forward dropout rate.
    """

    def __init__(self, d_model, num_heads, dff, initializer, mixed, scale, attn_pdrop, resid_pdrop):
        super().__init__()
        self.mha = MHA(d_model, num_heads, initializer, mixed, scale, attn_pdrop, resid_pdrop)

        self.ffn = point_wise_feed_forward_network(d_model, dff)

        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        self.dropout = tf.keras.layers.Dropout(resid_pdrop)

    def call(self, x, training, look_ahead_mask):
        """Apply the block to x.

        Returns:
        output, attention_weights
        output is the result of the second layer normalization;
        attention_weights is the second value returned by the MHA sub-layer.
        """
        # Self-attention: x is used as values, keys and queries.
        attn_out, attn_weights = self.mha(x, x, x, look_ahead_mask, training)  # (batch_size, target_seq_len, d_model)
        normed_attn = self.layernorm1(attn_out + x)

        ffn_out = self.dropout(self.ffn(normed_attn), training=training)  # (batch_size, target_seq_len, d_model)
        normed_ffn = self.layernorm2(ffn_out + normed_attn)  # (batch_size, target_seq_len, d_model)

        return normed_ffn, attn_weights
    
class Decoder(tf.keras.layers.Layer):
    """Stack of DecoderLayer blocks on top of token + position embeddings.

    Tokens and positions share one embedding table of ``target_vocab_size``
    rows: position p is looked up with id
    ``target_vocab_size - block_size + p``, i.e. the last ``block_size`` ids
    of the table are used for positions.

    Args:
    num_layers: Number of DecoderLayer blocks.
    d_model: Embedding width and model width.
    num_heads: Number of attention heads per block.
    dff: Hidden width of each block's feed-forward network.
    target_vocab_size: Number of rows of the shared embedding table.
    block_size: Maximum sequence length (number of position ids).
    initializer: Kernel initializer forwarded to every DecoderLayer.
    mixed, scale, attn_pdrop, resid_pdrop: Forwarded to every DecoderLayer.
    embed_pdrop: Dropout rate applied to the summed embeddings.
    """

    def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size, block_size,
                 initializer=tf.random_normal_initializer(stddev=0.02), mixed=True, scale=True,
                 attn_pdrop=0.1, resid_pdrop=0.1, embed_pdrop=0.1):
        super().__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model)
        # Position ids, shape (1, block_size): the last block_size ids of the table.
        self.pos_encoding = tf.reshape(
            tf.range(target_vocab_size - block_size, target_vocab_size), shape=[1, -1])

        self.dec_layers = [
            DecoderLayer(d_model, num_heads, dff, initializer, mixed, scale, attn_pdrop, resid_pdrop)
            for _ in range(num_layers)]
        self.dropout = tf.keras.layers.Dropout(embed_pdrop)

        self.mixed = mixed

    def call(self, x, training, look_ahead_mask):
        """Embed token ids and run them through every decoder block.

        Args:
        x: Token ids, indexed as (batch_size, seq_len); seq_len may be
           shorter than block_size.
        training: Forwarded to dropout and to every DecoderLayer.
        look_ahead_mask: Forwarded to every DecoderLayer.

        Returns:
        x, attention_weights
        x has shape (batch_size, seq_len, d_model). attention_weights maps
        'decoder_layer{i}_block1' (i starting at 1) to the attention weights
        returned by the i-th block.
        """
        seq_len = tf.shape(x)[1]
        attention_weights = {}

        x = self.embedding(x)  # (batch_size, seq_len, d_model)
        # Only take the position ids for the tokens actually present, so
        # inputs shorter than block_size work. With seq_len == block_size
        # this is the full position tensor. No sqrt(d_model) scaling is applied.
        positions = self.pos_encoding[:, :seq_len]
        x += self.embedding(positions)
        x = self.dropout(x, training=training)

        for layer_number, dec_layer in enumerate(self.dec_layers, start=1):
            x, block1 = dec_layer(x, training, look_ahead_mask)
            attention_weights[f'decoder_layer{layer_number}_block1'] = block1

        # x.shape == (batch_size, seq_len, d_model)
        return x, attention_weights
    
class GPT1(tf.keras.Model):
    """Decoder-only language model: Decoder followed by a Dense output layer.

    The output layer has ``target_vocab_size - block_size`` units.

    Args:
    num_layers, d_model, num_heads, dff, target_vocab_size, block_size,
    initializer, mixed, scale, attn_pdrop, resid_pdrop, embed_pdrop:
        Forwarded unchanged to Decoder.
    mixed: Additionally selects float16 (true) or float32 (false) for the
        look-ahead mask, and makes ``call`` cast its logits to float32.
    """

    def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size, block_size,
                 initializer, mixed, scale, attn_pdrop, resid_pdrop, embed_pdrop):
        super().__init__()

        self.decoder = Decoder(num_layers, d_model, num_heads, dff, target_vocab_size, block_size,
                               initializer, mixed, scale, attn_pdrop, resid_pdrop, embed_pdrop)
        self.final_layer = tf.keras.layers.Dense(target_vocab_size - block_size)
        self.mixed = mixed

    def call(self, inp, training=False):
        """Compute logits for every position of inp.

        Args:
        inp: Token ids, indexed as (batch_size, seq_len).
        training: Forwarded to the decoder. Defaults to False so the model
            can be called without the flag.

        Returns:
        final_output, attention_weights
        final_output has shape
        (batch_size, seq_len, target_vocab_size - block_size) and is float32
        when self.mixed is true. attention_weights is the dict returned by
        the decoder.
        """
        look_ahead_mask = self.create_masks(inp)
        dec_output, attention_weights = self.decoder(inp, training, look_ahead_mask)
        final_output = self.final_layer(dec_output)  # (batch_size, tar_seq_len, target_vocab_size-block_size)
        if self.mixed:
            # Cast directly instead of building a new Keras Activation layer
            # on every forward pass just to obtain float32 logits.
            final_output = tf.cast(final_output, tf.float32)
        return final_output, attention_weights

    def create_masks(self, tar):
        """Build the look-ahead mask for tar.

        Returns a (seq_len, seq_len) lower-triangular matrix of ones
        (including the diagonal), where seq_len is tf.shape(tar)[1]. Entry
        [i, j] is 1 when j <= i and 0 otherwise. The dtype is float16 when
        self.mixed is true, float32 otherwise.
        """
        size = tf.shape(tar)[1]
        mask_dtype = tf.float16 if self.mixed else tf.float32
        lower_triangle = tf.linalg.band_part(tf.ones((size, size)), -1, 0)
        return tf.cast(lower_triangle, dtype=mask_dtype)